#include "Mesh.h"
#include "Element.h"
#include <cerrno>
#include <climits>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <valarray>
#include <vector>
#include "regularMultigridSolver.h"

/// Exact solution of the model problem: u(x, y) = sin(pi x) * sin(2 pi y).
/// It is zero (up to rounding) on the boundary of the unit square.
double u(const double x, const double y) {
	// M_PI is a POSIX extension, not ISO C++; spelling pi out keeps the file
	// portable (same digits as the usual M_PI definition, so same double).
	constexpr double pi = 3.14159265358979323846;
	return std::sin(pi * x) * std::sin(2.0 * pi * y);
	// Alternative manufactured solutions (switch f together with these):
	//return -250*(x*x + y*y);
	//return std::sin(x + 1) * std::sin(y + 1);
}

/// Right-hand side matching the exact solution u = sin(pi x) sin(2 pi y):
/// -Laplace(u) = (pi^2 + 4 pi^2) u = 5 pi^2 u.
double f(const double x, const double y) {
	// Portable replacement for the non-standard M_PI macro (same value).
	constexpr double pi = 3.14159265358979323846;
	return 5.0 * pi * pi * u(x, y);
	// Right-hand sides for the commented-out alternatives of u:
	//   u = -250 (x^2 + y^2)           ->  -Laplace(u) = 1000
	//   u = sin(x + 1) sin(y + 1)      ->  -Laplace(u) = 2 u
	//return 1000;
	//return 2.0 * u(x, y);
}


int main(int argc, char *argv[])
{
	size_t Nx = 3;
	size_t Ny = 4;
	size_t itr_time = 10;
	if (argc > 3)
	{
		Nx = atoi(argv[1]);
		Ny = atoi(argv[2]);
		itr_time = atoi(argv[3]);
	}
	int nx = 1 << Nx;
	int ny = 1 << Ny;
	size_t N = Nx > Ny ? Ny : Nx;
	P1Element<double> ref_ele;
	double e = 0.0000000001;
	ref_ele.read_quad_info("data/triangle.tmp_geo", 1);

	int n_quad_pnts = ref_ele.get_n_quad_points();
	double volume = ref_ele.get_volume();

	Point<2, double> x0y0({0.0, 0.0});
	Point<2, double> x1y1({1.0, 1.0});
	double hx = (x1y1[0] - x0y0[1]) / nx;
	double hy = (x1y1[1] - x0y0[1]) / ny;
	size_t n_ele = nx*ny*2;
	size_t n_dof = ref_ele.get_n_dofs();
	size_t n_total_dofs = (nx + 1)*(ny + 1);
	std::valarray<int> dof(3);
	size_t ix, iy;
	std::vector<double> F(n_total_dofs, 0);
	std::vector<double> U(n_total_dofs, 0);

	//std::vector<double> M(n_total_dofs*n_total_dofs, 0);
	for(size_t ei = 0; ei < n_ele; ei++) {	
		if(ei%2 == 0) {
			ix = (ei/2)%nx;
			iy = (ei/2)/nx;
			Point<2, double> v0({ix*hx, iy*hy});
			Point<2, double> v1({(ix+1)*hx, iy*hy});
			Point<2, double> v2({ix*hx, (iy + 1)*hy});

			dof[0] = iy*(nx + 1) + ix;
			dof[1] = iy*(nx + 1) + ix + 1;
			dof[2] = (iy + 1)*(nx + 1) + ix;

			v0.set_index(dof[0]);
			v1.set_index(dof[1]);
			v2.set_index(dof[2]);

			ref_ele.set_global_dof(0, v0);
			ref_ele.set_global_dof(1, v1);
			ref_ele.set_global_dof(2, v2);
		}
		else {
			ix = ((ei - 1)/2)%nx;
			iy = ((ei - 1)/2)/nx;
			Point<2, double> v0({(ix + 1)*hx, iy*hy});
			Point<2, double> v1({(ix +1 )*hx, (iy + 1)*hy});
			Point<2, double> v2({ix*hx, (iy + 1)*hy});

			dof[0] = iy*(nx + 1) + ix + 1;
			dof[1] = (iy + 1)*(nx + 1) + ix + 1;
			dof[2] = (iy + 1)*(nx + 1) + ix;

			v0.set_index(dof[0]);
			v1.set_index(dof[1]);
			v2.set_index(dof[2]);

			ref_ele.set_global_dof(0, v0);
			ref_ele.set_global_dof(1, v1);
			ref_ele.set_global_dof(2, v2);
		}
		//std::cout << ix << iy << std::endl;
		int n_quad_pnts = ref_ele.get_n_quad_points();
		const std::valarray<Point<2, double> > &quad_pnts = ref_ele.get_quad_points();
		const std::valarray<double> &weights = ref_ele.get_quad_weights();
		double volume = ref_ele.get_volume();

		for(size_t l = 0; l != n_quad_pnts; l++) {
			double detJ = ref_ele.global2local_Jacobi_det(quad_pnts[l]);
			Matrix<double> invJ = ref_ele.local2global_Jacobi(quad_pnts[l]);

			for(size_t i = 0; i != n_dof; i++) {
				std::valarray<double> grad_i = invJ*ref_ele.basis_gradient(i, quad_pnts[l]);
				double phi_i = ref_ele.basis_function(i, quad_pnts[l]);
				size_t index = iy*(nx + 1) + ix;
				if(ei%2 == 0) {
					if(i == 1)
						index += 1;
					else if(i == 2)
						index += nx + 1;
				}
				else {
					if(i == 0)
						index += 1;
					else if(i == 1)
						index += nx + 2;
					else
						index += nx + 1;
				}
				F[index] += f(ref_ele.local2global(quad_pnts[l])[0], ref_ele.local2global(quad_pnts[l])[1])*phi_i*weights[l]*detJ*volume;
			}
		}
	}
	for(size_t i = 0; i != ny + 1; i++) {
		if(i == 0 || i == ny)
			for(size_t j = 0; j != nx + 1; j++) {
				U[i*(nx + 1) + j] = u(1.0*j/nx, 1.0*i/ny);
			}
		else {
			U[i*(nx + 1)] = u(0, 1.0*i/ny);
			U[i*(nx + 1) + nx] = u(1, 1.0*i/ny);
		}
	}	

	FMG_2d Solver;
//	std::cout << u(1.0, 1.0);
	Solver.Solve(U, F, Nx, Ny, 2, itr_time, 2, 2, e);
//	std::cout << "Solution:" << std::endl;
//	for(size_t i = 0; i != ny + 1; i++) {
//		for(size_t j = 0; j != nx + 1; j++) {
//			std::cout << Solver.get_U(i*(nx + 1) + j, N) << '\t' << u(1.0*j/nx, 1.0*i/ny)
//					  << '\t' << -1.0*f(1.0*j/nx, 1.0*i/ny) << std::endl;
//		}	
//	}
	for(size_t i = 0; i != n_total_dofs; i++) {
		U[i] = Solver.get_U(i, N);
	}
	//L2 error
	ref_ele.read_quad_info("data/triangle.tmp_geo", 4);
	double error = 0;
	const std::valarray<Point<2, double> > &quad_pnts = ref_ele.get_quad_points();
	const std::valarray<double> &weights = ref_ele.get_quad_weights();
	n_quad_pnts = ref_ele.get_n_quad_points();
	for(size_t ei = 0; ei < n_ele; ei++) {	
		if(ei%2 == 0) {
			ix = (ei/2)%nx;
			iy = (ei/2)/nx;
			Point<2, double> v0({ix*hx, iy*hy});
			Point<2, double> v1({(ix+1)*hx, iy*hy});
			Point<2, double> v2({ix*hx, (iy + 1)*hy});

			dof[0] = iy*(nx + 1) + ix;
			dof[1] = iy*(nx + 1) + ix + 1;
			dof[2] = (iy + 1)*(nx + 1) + ix;

			v0.set_index(dof[0]);
			v1.set_index(dof[1]);
			v2.set_index(dof[2]);

			ref_ele.set_global_dof(0, v0);
			ref_ele.set_global_dof(1, v1);
			ref_ele.set_global_dof(2, v2);
		}
		else {
			ix = ((ei - 1)/2)%nx;
			iy = ((ei - 1)/2)/nx;
			Point<2, double> v0({(ix + 1)*hx, iy*hy});
			Point<2, double> v1({(ix +1 )*hx, (iy + 1)*hy});
			Point<2, double> v2({ix*hx, (iy + 1)*hy});

			dof[0] = iy*(nx + 1) + ix + 1;
			dof[1] = (iy + 1)*(nx + 1) + ix + 1;
			dof[2] = (iy + 1)*(nx + 1) + ix;

			v0.set_index(dof[0]);
			v1.set_index(dof[1]);
			v2.set_index(dof[2]);

			ref_ele.set_global_dof(0, v0);
			ref_ele.set_global_dof(1, v1);
			ref_ele.set_global_dof(2, v2);
		}
		//std::cout << ix << iy << std::endl;
		int n_quad_pnts = ref_ele.get_n_quad_points();
		const std::valarray<Point<2, double> > &quad_pnts = ref_ele.get_quad_points();
		const std::valarray<double> &weights = ref_ele.get_quad_weights();
		double volume = ref_ele.get_volume();

		for(size_t l = 0; l != n_quad_pnts; l++) {
			double detJ = ref_ele.global2local_Jacobi_det(quad_pnts[l]);
			Matrix<double> invJ = ref_ele.local2global_Jacobi(quad_pnts[l]);
			double u_sol = 0;

			for(size_t i = 0; i != n_dof; i++) {
				std::valarray<double> grad_i = invJ*ref_ele.basis_gradient(i, quad_pnts[l]);
				double phi_i = ref_ele.basis_function(i, quad_pnts[l]);
				
				size_t index = iy*(nx + 1) + ix;
				if(ei%2 == 0) {
					if(i == 1)
						index += 1;
					else if(i == 2)
						index += nx + 1;
				}
				else {
					if(i == 0)
						index += 1;
					else if(i == 1)
						index += nx + 2;
					else
						index += nx + 1;
				}
				u_sol += phi_i*Solver.get_U(index, N);
			}
			//std::cout << "ASD:\t" << "(" << ref_ele.local2global(quad_pnts[l])[0] << ',' << ref_ele.local2global(quad_pnts[l])[1]  <<")"
			//	<< u(ref_ele.local2global(quad_pnts[l])[0], ref_ele.local2global(quad_pnts[l])[1])
			//	<<"\t" << u_sol << std::endl;
			error += (u(ref_ele.local2global(quad_pnts[l])[0], ref_ele.local2global(quad_pnts[l])[1]) - u_sol) 
				     *(u(ref_ele.local2global(quad_pnts[l])[0], ref_ele.local2global(quad_pnts[l])[1]) - u_sol)
					 *weights[l]*detJ*volume;
		}
	}
	error = std::sqrt(error);
	std::cout << "L2 error:" << error << std::endl;
};
